{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##### Co3D_Multiview\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from fast3r.dust3r.datasets.co3d_multiview import Co3d_Multiview\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "\n",
    "dataset = Co3d_Multiview(\n",
    "    split=\"train\", num_views=10, window_degree_range=360, num_samples_per_window=100, data_scaling=0.9, mask_bg='rand', ROOT=\"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed\", resolution=224, aug_crop=16,\n",
    ")\n",
    "\n",
    "# dataset = Co3d_Multiview(\n",
    "#     split=\"train\", num_views=40, window_degree_range=360, num_samples_per_window=1, mask_bg='rand', ROOT=\"/path/to/dust3r_data/co3d_all_seqs_per_category_subset_processed\", resolution=512, aug_crop=16,\n",
    "# )\n",
    "\n",
    "for idx in np.random.permutation(len(dataset)):\n",
    "    views = dataset[idx]\n",
    "    assert len(views) == dataset.num_views\n",
    "    print([view_name(view) for view in views])\n",
    "    viz = SceneViz()\n",
    "    poses = [views[view_idx][\"camera_pose\"] for view_idx in range(dataset.num_views)]\n",
    "    cam_size = max(auto_cam_size(poses), 1)\n",
    "    for view_idx in range(dataset.num_views):\n",
    "        pts3d = views[view_idx][\"pts3d\"]\n",
    "        valid_mask = views[view_idx][\"valid_mask\"]\n",
    "        colors = rgb(views[view_idx][\"img\"])\n",
    "        viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "        viz.add_camera(\n",
    "            pose_c2w=views[view_idx][\"camera_pose\"],\n",
    "            focal=views[view_idx][\"camera_intrinsics\"][0, 0],\n",
    "            color=(view_idx * 255, (1 - view_idx) * 255, 0),\n",
    "            image=colors,\n",
    "            cam_size=cam_size,\n",
    "        )\n",
    "    display(viz.show(point_size=100, viewer=\"notebook\"))\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "views[0]['camera_pose']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "views[view_idx][\"img\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import plotly.graph_objects as go\n",
    "from scipy.linalg import rq\n",
    "from fast3r.dust3r.datasets.co3d_multiview import Co3d_Multiview\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from IPython.display import display\n",
    "\n",
    "# Load dataset\n",
    "dataset = Co3d_Multiview(\n",
    "    split=\"train\", num_views=10, window_degree_range=360, num_samples_per_window=100, mask_bg='rand', \n",
    "    ROOT=\"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed\", resolution=224, aug_crop=16,\n",
    ")\n",
    "\n",
    "# Function to estimate the projection matrix using Direct Linear Transformation (DLT)\n",
    "def estimate_projection_matrix(world_points, image_points):\n",
    "    num_points = world_points.shape[0]\n",
    "    A = []\n",
    "\n",
    "    for i in range(num_points):\n",
    "        X, Y, Z = world_points[i]\n",
    "        u, v = image_points[i]\n",
    "        \n",
    "        A.append([-X, -Y, -Z, -1, 0, 0, 0, 0, u*X, u*Y, u*Z, u])\n",
    "        A.append([0, 0, 0, 0, -X, -Y, -Z, -1, v*X, v*Y, v*Z, v])\n",
    "    \n",
    "    A = np.array(A)\n",
    "    \n",
    "    # Solve using SVD (least squares solution)\n",
    "    U, S, Vh = np.linalg.svd(A)\n",
    "    P = Vh[-1, :].reshape(3, 4)\n",
    "    \n",
    "    return P\n",
    "\n",
    "# Function to decompose the projection matrix into intrinsic and extrinsic matrices\n",
    "def decompose_projection_matrix(P):\n",
    "    # Decompose P into K[R|t] using RQ decomposition\n",
    "    M = P[:, :3]\n",
    "    K, R = rq(M)\n",
    "    \n",
    "    # Normalize K to make sure the diagonal elements are positive\n",
    "    T = np.diag(np.sign(np.diag(K)))\n",
    "    K = K @ T\n",
    "    R = T @ R\n",
    "    \n",
    "    # Extract translation vector\n",
    "    t = np.linalg.inv(K) @ P[:, 3]\n",
    "    \n",
    "    return K, R, t\n",
    "\n",
    "# Function to plot the cameras as cones in 3D space based on the intrinsic matrix K\n",
    "def plot_camera_cones(fig, R, t, K, color='blue', scale=0.1):\n",
    "    \"\"\"\n",
    "    Plot the camera as a cone in 3D space based on the intrinsic matrix K for focal length.\n",
    "    \n",
    "    Parameters:\n",
    "    fig (plotly.graph_objects.Figure): The existing Plotly figure.\n",
    "    R (np.ndarray): The 3x3 rotation matrix.\n",
    "    t (np.ndarray): The 3x1 translation vector.\n",
    "    K (np.ndarray): The 3x3 intrinsic matrix.\n",
    "    color (str): Color of the camera cone.\n",
    "    scale (float): Scale factor for the size of the cone base.\n",
    "    \"\"\"\n",
    "    # The focal length is the element K[0, 0] (assuming fx and fy are equal)\n",
    "    focal_length = K[0, 0] / K[2, 2]\n",
    "\n",
    "    # The camera center (apex of the cone)\n",
    "    camera_center = -R.T @ t\n",
    "\n",
    "    # Define the orientation of the cone based on the inverse of the rotation matrix\n",
    "    direction = R.T @ np.array([0, 0, -1])  # Camera looks along the -Z axis in world space\n",
    "\n",
    "    # Scale the direction by the focal length\n",
    "    direction = direction * focal_length\n",
    "\n",
    "    # Plot the camera cone\n",
    "    fig.add_trace(go.Cone(\n",
    "        x=[camera_center[0]],\n",
    "        y=[camera_center[1]],\n",
    "        z=[camera_center[2]],\n",
    "        u=[direction[0]],\n",
    "        v=[direction[1]],\n",
    "        w=[direction[2]],\n",
    "        colorscale=[[0, color], [1, color]],  # Single color for the cone\n",
    "        showscale=False,\n",
    "        sizemode=\"absolute\",\n",
    "        sizeref=scale,  # The size of the cone base\n",
    "        anchor=\"tip\",  # The tip of the cone is the camera center\n",
    "        name=\"Camera Cone\"\n",
    "    ))\n",
    "\n",
    "# Function to visualize 3D points with RGB colors and estimated camera poses as cones using Plotly\n",
    "def plot_3d_scene_with_estimated_poses(points_list, colors_list, estimated_poses):\n",
    "    fig = go.Figure()\n",
    "\n",
    "    # Plot 3D points with RGB colors\n",
    "    for pts3d, colors in zip(points_list, colors_list):\n",
    "        x, y, z = pts3d[:, 0], pts3d[:, 1], pts3d[:, 2]\n",
    "        colors_rgb = colors.reshape(-1, 3)\n",
    "        fig.add_trace(go.Scatter3d(\n",
    "            x=x, y=y, z=z, mode='markers',\n",
    "            marker=dict(size=2, color=colors_rgb, colorscale=None, opacity=0.8),\n",
    "            name='3D Points'\n",
    "        ))\n",
    "\n",
    "    # Plot estimated camera cones\n",
    "    for idx, (R, t, K) in enumerate(estimated_poses):\n",
    "        plot_camera_cones(fig, R, t, K, color='blue', scale=5)\n",
    "\n",
    "    # Update layout for better visualization\n",
    "    fig.update_layout(\n",
    "        scene=dict(\n",
    "            xaxis_title='X',\n",
    "            yaxis_title='Y',\n",
    "            zaxis_title='Z',\n",
    "            aspectmode='data'\n",
    "        ),\n",
    "        margin=dict(r=0, l=0, b=0, t=0)\n",
    "    )\n",
    "    \n",
    "    fig.show()\n",
    "\n",
    "# Processing a single batch of views\n",
    "def process_views(N=5000):\n",
    "    for idx in np.random.permutation(len(dataset)):\n",
    "        views = dataset[idx]\n",
    "\n",
    "        # Collect all 3D points, RGB colors, and estimated poses for visualization\n",
    "        points_list = []\n",
    "        colors_list = []\n",
    "        estimated_poses = []\n",
    "\n",
    "        for view_idx in range(dataset.num_views):\n",
    "            pts3d = views[view_idx][\"pts3d\"]  # (224, 224, 3)\n",
    "            valid_mask = views[view_idx][\"valid_mask\"]  # Only keep valid points\n",
    "\n",
    "            # Flatten the valid 3D points\n",
    "            pts3d = pts3d.reshape(-1, 3)\n",
    "            valid_mask_flat = valid_mask.flatten()\n",
    "            pts3d = pts3d[valid_mask_flat]\n",
    "\n",
    "            # Flatten the RGB image and apply the valid mask\n",
    "            img_rgb = rgb(views[view_idx][\"img\"]).reshape(-1, 3)\n",
    "            img_rgb = img_rgb[valid_mask_flat]\n",
    "\n",
    "            # Generate x and y coordinates for the image\n",
    "            x_coords = np.tile(np.arange(224), 224)\n",
    "            y_coords = np.repeat(np.arange(224), 224)\n",
    "            pixel_coords = np.stack((x_coords, y_coords), axis=1)\n",
    "            valid_pixel_coords = pixel_coords[valid_mask_flat]\n",
    "\n",
    "            # Sample N points to speed up estimation\n",
    "            if len(pts3d) > N:\n",
    "                sample_indices = np.random.choice(len(pts3d), N, replace=False)\n",
    "                pts3d = pts3d[sample_indices]\n",
    "                img_rgb = img_rgb[sample_indices]\n",
    "                valid_pixel_coords = valid_pixel_coords[sample_indices]\n",
    "\n",
    "            points_list.append(pts3d)\n",
    "            colors_list.append(img_rgb)\n",
    "\n",
    "            image_points = valid_pixel_coords  # Now image_points correspond to pts3d\n",
    "\n",
    "            # Estimate projection matrix for this view\n",
    "            P = estimate_projection_matrix(pts3d, image_points)\n",
    "            \n",
    "            # Decompose the projection matrix into intrinsic and extrinsic matrices\n",
    "            K, R, t = decompose_projection_matrix(P)\n",
    "\n",
    "            # Print the estimated K, R, and t\n",
    "            print(f\"View {view_idx} - Intrinsic matrix (K):\\n{K}\")\n",
    "            print(f\"View {view_idx} - Rotation matrix (R):\\n{R}\")\n",
    "            print(f\"View {view_idx} - Translation vector (t):\\n{t}\\n\")\n",
    "            \n",
    "            # Store the estimated rotation (R), translation (t), and intrinsic matrix (K)\n",
    "            estimated_poses.append((R, t, K))\n",
    "\n",
    "        # Plot the 3D scene with estimated camera cones\n",
    "        plot_3d_scene_with_estimated_poses(points_list, colors_list, estimated_poses)\n",
    "        \n",
    "        break  # Process one sample\n",
    "\n",
    "\n",
    "# Run the process with N point sampling\n",
    "process_views(N=10000)  # You can change N for faster/slower performance\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using Ground Truth Intrinsic Matrix + cv2.solvePnP\n",
    "\n",
    "import numpy as np\n",
    "import cv2  # OpenCV for solvePnP\n",
    "from fast3r.dust3r.datasets.co3d_multiview import Co3d_Multiview\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "# Load dataset\n",
    "dataset = Co3d_Multiview(\n",
    "    split=\"train\", num_views=10, window_degree_range=360, num_samples_per_window=100, mask_bg='rand', \n",
    "    ROOT=\"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed\", resolution=224, aug_crop=16,\n",
    ")\n",
    "\n",
    "# Function to convert estimated rotation and translation (R, t) into a camera pose (4x4 matrix)\n",
    "def Rt_to_pose(R, t):\n",
    "    \"\"\"Convert rotation matrix and translation vector to a 4x4 camera pose matrix.\"\"\"\n",
    "    pose = np.eye(4)\n",
    "    pose[:3, :3] = R\n",
    "    pose[:3, 3] = t[:, 0]  # Convert t from (3, 1) to (3,) shape\n",
    "    return pose\n",
    "\n",
    "# Function to invert a 4x4 pose matrix (world-to-camera to camera-to-world)\n",
    "def invert_pose(pose):\n",
    "    \"\"\"Invert a 4x4 pose matrix.\"\"\"\n",
    "    R_inv = pose[:3, :3].T  # Transpose the rotation part\n",
    "    t_inv = -R_inv @ pose[:3, 3]  # Invert the translation\n",
    "    pose_inv = np.eye(4)\n",
    "    pose_inv[:3, :3] = R_inv\n",
    "    pose_inv[:3, 3] = t_inv\n",
    "    return pose_inv\n",
    "\n",
    "# Processing a single batch of views\n",
    "def process_views(N=5000):\n",
    "    for idx in np.random.permutation(len(dataset)):\n",
    "        views = dataset[idx]\n",
    "        assert len(views) == dataset.num_views\n",
    "        print([view_name(view) for view in views])\n",
    "\n",
    "        # Initialize SceneViz for visualization\n",
    "        viz = SceneViz()\n",
    "        \n",
    "        # Estimate camera poses and set up visualization\n",
    "        points_list = []\n",
    "        colors_list = []\n",
    "        estimated_poses = []\n",
    "        poses_c2w = []  # List for the camera-to-world poses to visualize\n",
    "\n",
    "        for view_idx in range(dataset.num_views):\n",
    "            pts3d = views[view_idx][\"pts3d\"]  # (224, 224, 3)\n",
    "            valid_mask = views[view_idx][\"valid_mask\"]  # Only keep valid points\n",
    "            img_rgb = rgb(views[view_idx][\"img\"])\n",
    "\n",
    "            # Flatten the valid 3D points\n",
    "            pts3d = pts3d.reshape(-1, 3)\n",
    "            valid_mask_flat = valid_mask.flatten()\n",
    "            pts3d = pts3d[valid_mask_flat]\n",
    "\n",
    "            # Flatten the RGB image and apply the valid mask\n",
    "            img_rgb = img_rgb.reshape(-1, 3)\n",
    "            img_rgb = img_rgb[valid_mask_flat]\n",
    "\n",
    "            # Generate x and y coordinates for the image\n",
    "            x_coords = np.tile(np.arange(224), 224)\n",
    "            y_coords = np.repeat(np.arange(224), 224)\n",
    "            pixel_coords = np.stack((x_coords, y_coords), axis=1)\n",
    "            valid_pixel_coords = pixel_coords[valid_mask_flat]\n",
    "\n",
    "            # Sample N points to speed up estimation\n",
    "            if len(pts3d) > N:\n",
    "                sample_indices = np.random.choice(len(pts3d), N, replace=False)\n",
    "                pts3d = pts3d[sample_indices]\n",
    "                img_rgb = img_rgb[sample_indices]\n",
    "                valid_pixel_coords = valid_pixel_coords[sample_indices]\n",
    "\n",
    "            points_list.append(pts3d)\n",
    "            colors_list.append(img_rgb)\n",
    "\n",
    "            image_points = valid_pixel_coords  # Now image_points correspond to pts3d\n",
    "\n",
    "            # Convert pts3d and image_points to float32\n",
    "            pts3d = pts3d.astype(np.float32)\n",
    "            image_points = image_points.astype(np.float32)\n",
    "\n",
    "            # Get intrinsic matrix from the dataset and ensure it's float32\n",
    "            K = np.array(views[view_idx][\"camera_intrinsics\"], dtype=np.float32)\n",
    "\n",
    "            # Check if we have at least 4 points\n",
    "            if len(pts3d) < 4 or len(image_points) < 4:\n",
    "                raise ValueError(\"Not enough points to run solvePnP. Need at least 4.\")\n",
    "\n",
    "            # Solve for the camera pose (R, t) using OpenCV's solvePnP\n",
    "            success, rvec, tvec = cv2.solvePnP(pts3d, image_points, K, None)\n",
    "            R, _ = cv2.Rodrigues(rvec)  # Convert rotation vector to matrix\n",
    "\n",
    "            # Convert (R, t) to world-to-camera pose matrix (4x4)\n",
    "            pose_w2c = Rt_to_pose(R, tvec)\n",
    "\n",
    "            # Invert the pose to get camera-to-world pose\n",
    "            pose_c2w = invert_pose(pose_w2c)\n",
    "            poses_c2w.append(pose_c2w)\n",
    "\n",
    "        # Use auto_cam_size to get the camera size for visualization\n",
    "        cam_size = max(auto_cam_size(poses_c2w), 1)\n",
    "\n",
    "        # Add the point clouds and estimated camera poses to the visualization\n",
    "        for view_idx in range(dataset.num_views):\n",
    "            pts3d = views[view_idx][\"pts3d\"]\n",
    "            valid_mask = views[view_idx][\"valid_mask\"]\n",
    "            colors = rgb(views[view_idx][\"img\"])\n",
    "\n",
    "            # Add the pointcloud to the visualization\n",
    "            viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "\n",
    "            # Add the estimated camera pose (camera-to-world matrix)\n",
    "            viz.add_camera(\n",
    "                pose_c2w=poses_c2w[view_idx],  # Use the inverted camera-to-world pose\n",
    "                focal=views[view_idx][\"camera_intrinsics\"][0, 0],\n",
    "                color=(view_idx * 255, (1 - view_idx) * 255, 0),\n",
    "                image=colors,\n",
    "                cam_size=cam_size,\n",
    "            )\n",
    "\n",
    "        # Show the visualization\n",
    "        display(viz.show(point_size=100, viewer=\"notebook\"))\n",
    "\n",
    "        break  # Process one sample\n",
    "\n",
    "\n",
    "# Run the process\n",
    "process_views(N=10000)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Using DLT to estimate the intrinsic matrix (and the camera pose)\n",
    "\n",
    "import numpy as np\n",
    "from scipy.linalg import rq  # For RQ decomposition\n",
    "from fast3r.dust3r.datasets.co3d_multiview import Co3d_Multiview\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "# Load dataset\n",
    "dataset = Co3d_Multiview(\n",
    "    split=\"train\", num_views=10, window_degree_range=360, num_samples_per_window=100, mask_bg='rand', \n",
    "    ROOT=\"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed\", resolution=224, aug_crop=16,\n",
    ")\n",
    "\n",
    "# Function to estimate the projection matrix using Direct Linear Transformation (DLT)\n",
    "def estimate_projection_matrix(world_points, image_points):\n",
    "    num_points = world_points.shape[0]\n",
    "    A = []\n",
    "\n",
    "    for i in range(num_points):\n",
    "        X, Y, Z = world_points[i]\n",
    "        u, v = image_points[i]\n",
    "        \n",
    "        A.append([-X, -Y, -Z, -1, 0, 0, 0, 0, u*X, u*Y, u*Z, u])\n",
    "        A.append([0, 0, 0, 0, -X, -Y, -Z, -1, v*X, v*Y, v*Z, v])\n",
    "    \n",
    "    A = np.array(A)\n",
    "    \n",
    "    # Solve using SVD (least squares solution)\n",
    "    U, S, Vh = np.linalg.svd(A)\n",
    "    P = Vh[-1, :].reshape(3, 4)\n",
    "    \n",
    "    return P\n",
    "\n",
    "# Function to decompose the projection matrix into intrinsic and extrinsic matrices\n",
    "def decompose_projection_matrix(P):\n",
    "    \"\"\"Decompose the projection matrix P into intrinsic matrix K and extrinsic parameters (R, t).\"\"\"\n",
    "    # Decompose P into K[R|t] using RQ decomposition\n",
    "    M = P[:, :3]\n",
    "    K, R = rq(M)\n",
    "    \n",
    "    # Normalize K to make sure the diagonal elements are positive\n",
    "    T = np.diag(np.sign(np.diag(K)))\n",
    "    K = K @ T\n",
    "    R = T @ R\n",
    "    \n",
    "    # Extract translation vector\n",
    "    t = np.linalg.inv(K) @ P[:, 3]\n",
    "    \n",
    "    return K, R, t\n",
    "\n",
    "# Function to convert estimated rotation and translation (R, t) into a camera pose (4x4 matrix)\n",
    "def Rt_to_pose(R, t):\n",
    "    \"\"\"Convert rotation matrix and translation vector to a 4x4 camera pose matrix.\"\"\"\n",
    "    pose = np.eye(4)\n",
    "    pose[:3, :3] = R\n",
    "    pose[:3, 3] = t[:, 0]  # Convert t from (3, 1) to (3,) shape\n",
    "    return pose\n",
    "\n",
    "# Function to invert a 4x4 pose matrix (world-to-camera to camera-to-world)\n",
    "def invert_pose(pose):\n",
    "    \"\"\"Invert a 4x4 pose matrix.\"\"\"\n",
    "    R_inv = pose[:3, :3].T  # Transpose the rotation part\n",
    "    t_inv = -R_inv @ pose[:3, 3]  # Invert the translation\n",
    "    pose_inv = np.eye(4)\n",
    "    pose_inv[:3, :3] = R_inv\n",
    "    pose_inv[:3, 3] = t_inv\n",
    "    return pose_inv\n",
    "\n",
    "# Processing a single batch of views\n",
    "def process_views(N=5000):\n",
    "    for idx in np.random.permutation(len(dataset)):\n",
    "        views = dataset[idx]\n",
    "        assert len(views) == dataset.num_views\n",
    "        print([view_name(view) for view in views])\n",
    "\n",
    "        # Initialize SceneViz for visualization\n",
    "        viz = SceneViz()\n",
    "        \n",
    "        # Estimate camera poses and intrinsics, and set up visualization\n",
    "        points_list = []\n",
    "        colors_list = []\n",
    "        estimated_poses = []\n",
    "        poses_c2w = []  # List for the camera-to-world poses to visualize\n",
    "\n",
    "        for view_idx in range(dataset.num_views):\n",
    "            pts3d = views[view_idx][\"pts3d\"]  # (224, 224, 3)\n",
    "            valid_mask = views[view_idx][\"valid_mask\"]  # Only keep valid points\n",
    "            img_rgb = rgb(views[view_idx][\"img\"])\n",
    "\n",
    "            # Flatten the valid 3D points\n",
    "            pts3d = pts3d.reshape(-1, 3)\n",
    "            valid_mask_flat = valid_mask.flatten()\n",
    "            pts3d = pts3d[valid_mask_flat]\n",
    "\n",
    "            # Flatten the RGB image and apply the valid mask\n",
    "            img_rgb = img_rgb.reshape(-1, 3)\n",
    "            img_rgb = img_rgb[valid_mask_flat]\n",
    "\n",
    "            # Generate x and y coordinates for the image\n",
    "            x_coords = np.tile(np.arange(224), 224)\n",
    "            y_coords = np.repeat(np.arange(224), 224)\n",
    "            pixel_coords = np.stack((x_coords, y_coords), axis=1)\n",
    "            valid_pixel_coords = pixel_coords[valid_mask_flat]\n",
    "\n",
    "            # Sample N points to speed up estimation\n",
    "            if len(pts3d) > N:\n",
    "                sample_indices = np.random.choice(len(pts3d), N, replace=False)\n",
    "                pts3d = pts3d[sample_indices]\n",
    "                img_rgb = img_rgb[sample_indices]\n",
    "                valid_pixel_coords = valid_pixel_coords[sample_indices]\n",
    "\n",
    "            points_list.append(pts3d)\n",
    "            colors_list.append(img_rgb)\n",
    "\n",
    "            image_points = valid_pixel_coords  # Now image_points correspond to pts3d\n",
    "\n",
    "            # Convert pts3d and image_points to float32\n",
    "            pts3d = pts3d.astype(np.float32)\n",
    "            image_points = image_points.astype(np.float32)\n",
    "\n",
    "            # Estimate the projection matrix using DLT\n",
    "            P = estimate_projection_matrix(pts3d, image_points)\n",
    "\n",
    "            # Decompose the projection matrix into intrinsics and extrinsics\n",
    "            K, R, t = decompose_projection_matrix(P)\n",
    "\n",
    "            # Print the estimated intrinsics and extrinsics\n",
    "            print(f\"View {view_idx} - Estimated Intrinsic matrix (K):\\n{K}\")\n",
    "            print(f\"View {view_idx} - Estimated Rotation matrix (R):\\n{R}\")\n",
    "            print(f\"View {view_idx} - Estimated Translation vector (t):\\n{t}\\n\")\n",
    "\n",
    "            # Convert (R, t) to world-to-camera pose matrix (4x4)\n",
    "            pose_w2c = Rt_to_pose(R, t.reshape(-1, 1))\n",
    "\n",
    "            # Invert the pose to get camera-to-world pose\n",
    "            pose_c2w = invert_pose(pose_w2c)\n",
    "            poses_c2w.append(pose_c2w)\n",
    "\n",
    "        # Use auto_cam_size to get the camera size for visualization\n",
    "        cam_size = max(auto_cam_size(poses_c2w), 1)\n",
    "\n",
    "        # Add the point clouds and estimated camera poses to the visualization\n",
    "        for view_idx in range(dataset.num_views):\n",
    "            pts3d = views[view_idx][\"pts3d\"]\n",
    "            valid_mask = views[view_idx][\"valid_mask\"]\n",
    "            colors = rgb(views[view_idx][\"img\"])\n",
    "\n",
    "            # Add the pointcloud to the visualization\n",
    "            viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "\n",
    "            # Add the estimated camera pose (camera-to-world matrix)\n",
    "            viz.add_camera(\n",
    "                pose_c2w=poses_c2w[view_idx],  # Use the inverted camera-to-world pose\n",
    "                focal=K[0, 0] / K[2, 2],  # Use the estimated focal length from K\n",
    "                # focal=None,\n",
    "                color=(view_idx * 255, (1 - view_idx) * 255, 0),\n",
    "                image=colors,\n",
    "                cam_size=cam_size,\n",
    "            )\n",
    "\n",
    "        # Show the visualization\n",
    "        display(viz.show(point_size=100, viewer=\"notebook\"))\n",
    "\n",
    "        break  # Process one sample\n",
    "\n",
    "\n",
    "# Run the process\n",
    "process_views(N=10000)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guess focal length and use cv2.solvePnPRansac to solve for extrinsics\n",
    "\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import cv2\n",
    "from fast3r.dust3r.datasets.co3d_multiview import Co3d_Multiview\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "from fast3r.dust3r.cloud_opt.init_im_poses import fast_pnp  # Import fast_pnp\n",
    "\n",
    "# Load dataset\n",
    "dataset = Co3d_Multiview(\n",
    "    split=\"train\", num_views=2, window_degree_range=360, num_samples_per_window=100, mask_bg='rand', \n",
    "    ROOT=\"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed\", resolution=224, aug_crop=16,\n",
    ")\n",
    "\n",
    "# Function to process views and estimate camera poses using fast_pnp\n",
    "def process_views_with_fast_pnp(niter_PnP=10):\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    \n",
    "    for idx in np.random.permutation(len(dataset)):\n",
    "        views = dataset[idx]\n",
    "        assert len(views) == dataset.num_views\n",
    "        print([view_name(view) for view in views])\n",
    "\n",
    "        # Initialize SceneViz for visualization\n",
    "        viz = SceneViz()\n",
    "        \n",
    "        # Estimate camera poses and focal lengths, and set up visualization\n",
    "        points_list = []\n",
    "        colors_list = []\n",
    "        estimated_poses = []\n",
    "        estimated_focals = []  # List for the guessed focal lengths\n",
    "        poses_c2w = []  # List for the camera-to-world poses to visualize\n",
    "\n",
    "        for view_idx in range(dataset.num_views):\n",
    "            pts3d = views[view_idx][\"pts3d\"]  # (224, 224, 3) shape\n",
    "            valid_mask = views[view_idx][\"valid_mask\"]  # (224, 224) mask\n",
    "            img_rgb = rgb(views[view_idx][\"img\"])\n",
    "\n",
    "            # Do not flatten pts3d or valid_mask here for fast_pnp\n",
    "            points_list.append(pts3d)\n",
    "            colors_list.append(img_rgb)\n",
    "\n",
    "            # Call fast_pnp with unflattened pts3d and mask\n",
    "            focal_length, pose_c2w = fast_pnp(\n",
    "                torch.tensor(pts3d, device=device),  # Pass original unmasked pts3d\n",
    "                None,  # Guess focal length\n",
    "                torch.tensor(valid_mask, device=device, dtype=torch.bool),  # Valid mask (unflattened)\n",
    "                device,\n",
    "                pp=None,  # Use default principal point (center of image)\n",
    "                niter_PnP=niter_PnP\n",
    "            )\n",
    "\n",
    "            if pose_c2w is None:\n",
    "                print(f\"Failed to estimate pose for view {view_idx}\")\n",
    "                continue\n",
    "\n",
    "            # Store the estimated camera-to-world pose and focal length\n",
    "            poses_c2w.append(pose_c2w.cpu().numpy())\n",
    "            estimated_focals.append(focal_length)\n",
    "            print(f\"View {view_idx} - Estimated Focal Length: {focal_length}\")\n",
    "\n",
    "        # Use auto_cam_size to get the camera size for visualization\n",
    "        cam_size = max(auto_cam_size(poses_c2w), 1)\n",
    "\n",
    "        # Add the point clouds and estimated camera poses to the visualization\n",
    "        for view_idx in range(dataset.num_views):\n",
    "            pts3d = views[view_idx][\"pts3d\"]\n",
    "            valid_mask = views[view_idx][\"valid_mask\"]\n",
    "            colors = rgb(views[view_idx][\"img\"])\n",
    "\n",
    "            # Add the pointcloud to the visualization\n",
    "            viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "\n",
    "            # Add the estimated camera pose (camera-to-world matrix) and focal length\n",
    "            viz.add_camera(\n",
    "                pose_c2w=poses_c2w[view_idx],  # Use the estimated camera-to-world pose\n",
    "                focal=estimated_focals[view_idx],  # Use the estimated focal length for each view\n",
    "                color=np.random.randint(0, 256, size=3),  # Generate a random RGB color\n",
    "                image=colors,\n",
    "                cam_size=cam_size,\n",
    "            )\n",
    "\n",
    "        # Show the visualization\n",
    "        display(viz.show(point_size=100, viewer=\"notebook\"))\n",
    "\n",
    "        break  # Process one sample\n",
    "\n",
    "\n",
    "# Run the process using fast_pnp\n",
    "process_views_with_fast_pnp()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import fast3r\n",
    "from fast3r.dust3r.viz_plotly import SceneViz  # Import Plotly version for visualization\n",
    "\n",
    "import importlib\n",
    "importlib.reload(fast3r.dust3r.viz_plotly)\n",
    "\n",
    "# Load dataset\n",
    "dataset = Co3d_Multiview(\n",
    "    split=\"train\", num_views=2, window_degree_range=360, num_samples_per_window=100, mask_bg='rand', \n",
    "    ROOT=\"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed\", resolution=224, aug_crop=16,\n",
    ")\n",
    "\n",
    "# Run the process using fast_pnp\n",
    "process_views_with_fast_pnp()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import plotly.graph_objects as go\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.utils import shuffle\n",
    "from PIL import Image\n",
    "\n",
    "def image2zvals(img, n_colors=4, n_training_pixels=1000, random_seed=42):\n",
    "    \"\"\"Quantize an image's colors with K-means and map each pixel to a normalized cluster index.\n",
    "\n",
    "    Args:\n",
    "        img: (H, W, C) array with C >= 3; values in [0, 255] or [0, 1]. Channels past the third are ignored.\n",
    "        n_colors: Number of K-means clusters; must be >= 2 so the colorscale spans [0, 1].\n",
    "        n_training_pixels: Number of randomly sampled pixels used to fit K-means.\n",
    "        random_seed: Seed for the pixel shuffle and for K-means.\n",
    "\n",
    "    Returns:\n",
    "        z_vals: (H, W) float array in [0, 1], equal to cluster index / (n_colors - 1).\n",
    "        plotly_colorscale: list of [position, 'rgb(r, g, b)'] pairs, one per cluster.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If img is not (H, W, C>=3) or n_colors < 2.\n",
    "    \"\"\"\n",
    "    if img.ndim != 3:\n",
    "        raise ValueError(f\"Your image does not appear to be a color image. Its shape is {img.shape}\")\n",
    "\n",
    "    rows, cols, d = img.shape\n",
    "    if d < 3:\n",
    "        raise ValueError(f\"A color image should have the shape (m, n, d), d=3 or 4. Your d={d}\")\n",
    "    if n_colors < 2:\n",
    "        raise ValueError(f\"n_colors must be >= 2, got {n_colors}\")\n",
    "\n",
    "    # Heuristic: values above 1 are assumed to be 8-bit data.\n",
    "    if img.max() > 1:\n",
    "        img = np.clip(img / 255.0, 0, 1)\n",
    "\n",
    "    observations = img[:, :, :3].reshape(rows * cols, 3)\n",
    "    training_pixels = shuffle(observations, random_state=random_seed)[:n_training_pixels]\n",
    "\n",
    "    kmeans = KMeans(n_clusters=n_colors, random_state=random_seed).fit(training_pixels)\n",
    "    codebook = kmeans.cluster_centers_\n",
    "    indices = kmeans.predict(observations)\n",
    "\n",
    "    z_vals = (indices.astype(float) / (n_colors - 1)).reshape(rows, cols)  # Normalize to [0, 1]\n",
    "\n",
    "    # Colorscale stop i sits exactly at z = i / (n_colors - 1), so each pixel gets its cluster color.\n",
    "    scale = np.linspace(0, 1, n_colors)\n",
    "    colors = np.clip(np.rint(codebook * 255), 0, 255).astype(int)\n",
    "    # Format plain ints: str(tuple(...)) of NumPy scalars renders 'np.uint8(12)' on NumPy >= 2.\n",
    "    plotly_colorscale = [[float(s), f\"rgb({r}, {g}, {b})\"] for s, (r, g, b) in zip(scale, colors)]\n",
    "\n",
    "    return z_vals, plotly_colorscale\n",
    "\n",
    "def plot_quantized_heatmap(image):\n",
    "    \"\"\"Plot the color-quantized image as a 2D heatmap to debug the color quantization.\"\"\"\n",
    "    z_vals, pl_colorscale = image2zvals(image)\n",
    "\n",
    "    fig = go.Figure(data=go.Heatmap(\n",
    "        z=z_vals,\n",
    "        colorscale=pl_colorscale,\n",
    "        zmin=0, zmax=1,  # pin the z range so colorscale stops line up with cluster levels\n",
    "        showscale=False\n",
    "    ))\n",
    "\n",
    "    fig.update_layout(\n",
    "        title=\"Quantized Image Heatmap\",\n",
    "        xaxis=dict(visible=False),\n",
    "        # Heatmaps draw row 0 at the bottom: reverse y so the image is upright,\n",
    "        # and anchor y to x so pixels stay square.\n",
    "        yaxis=dict(visible=False, autorange=\"reversed\", scaleanchor=\"x\"),\n",
    "        width=600,\n",
    "        height=600\n",
    "    )\n",
    "\n",
    "    fig.show()\n",
    "\n",
    "# Example usage; convert to RGB so grayscale/palette files still pass the color-image check\n",
    "image_path = '/path/to/beef_jerky/IMG_0050.jpg'\n",
    "image = np.array(Image.open(image_path).convert(\"RGB\"))\n",
    "\n",
    "plot_quantized_heatmap(image)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from PIL import Image\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.utils import shuffle\n",
    "import plotly.graph_objects as go\n",
    "\n",
    "# Test surface (modified from the blog)\n",
    "def surface(rows, cols):\n",
    "    \"\"\"Return (x, y, z) grids of shape (rows, cols) over [-pi, pi]^2 with z = 0.5*cos(x/2) + 0.2*sin(y/4).\"\"\"\n",
    "    xs, ys = np.meshgrid(np.linspace(-np.pi, np.pi, cols), np.linspace(-np.pi, np.pi, rows))\n",
    "    return xs, ys, 0.5 * np.cos(xs / 2) + 0.2 * np.sin(ys / 4)\n",
    "\n",
    "# Helper function to quantize the image using K-means.\n",
    "# NOTE: this redefines (and shadows) the image2zvals from the heatmap cell, with different defaults.\n",
    "def image2zvals(img, n_colors=64, n_training_pixels=10000, rngs=123):\n",
    "    \"\"\"Quantize the image to n_colors with K-means.\n",
    "\n",
    "    Returns (z_vals, plotly_colorscale). z_vals is an (H, W) array of cluster index / (n_colors - 1)\n",
    "    in [0, 1]. plotly_colorscale maps each of those levels to its cluster's RGB color.\n",
    "    `rngs` seeds both the pixel shuffle and K-means. Raises ValueError if n_colors < 2.\n",
    "    \"\"\"\n",
    "    if n_colors < 2:\n",
    "        raise ValueError(f\"n_colors must be >= 2, got {n_colors}\")\n",
    "    rows, cols, _ = img.shape\n",
    "\n",
    "    # Normalize the image if necessary (values above 1 are assumed to be 8-bit)\n",
    "    if img.max() > 1:\n",
    "        img = np.clip(img / 255.0, 0, 1)\n",
    "\n",
    "    observations = img[:, :, :3].reshape(rows * cols, 3)\n",
    "    training_pixels = shuffle(observations, random_state=rngs)[:n_training_pixels]\n",
    "\n",
    "    kmeans = KMeans(n_clusters=n_colors, random_state=rngs).fit(training_pixels)\n",
    "    codebook = kmeans.cluster_centers_\n",
    "    indices = kmeans.predict(observations)\n",
    "\n",
    "    z_vals = (indices.astype(float) / (n_colors - 1)).reshape(rows, cols)  # Normalize to [0, 1]\n",
    "\n",
    "    # Plotly colorscale: stop i sits exactly at z = i / (n_colors - 1)\n",
    "    scale = np.linspace(0, 1, n_colors)\n",
    "    colors = np.clip(np.rint(codebook * 255), 0, 255).astype(int)\n",
    "    # Format plain ints: str(tuple(...)) of NumPy scalars renders 'np.uint8(12)' on NumPy >= 2.\n",
    "    plotly_colorscale = [[float(s), f\"rgb({r}, {g}, {b})\"] for s, (r, g, b) in zip(scale, colors)]\n",
    "\n",
    "    return z_vals, plotly_colorscale\n",
    "\n",
    "# Generate triangles for the mesh\n",
    "def regular_triangles(rows, cols):\n",
    "    \"\"\"Triangulate a row-major (rows x cols) vertex grid into an (n_triangles, 3) int array.\n",
    "\n",
    "    The cell whose top-left vertex is k = j + i * cols yields two triangles, in this order:\n",
    "    [k, k + cols, k + 1 + cols] and [k, k + 1 + cols, k + 1]. Cells are visited row by row.\n",
    "    Grids with fewer than 2 rows or columns give an empty (0, 3) array, so callers can still\n",
    "    unpack its columns.\n",
    "    \"\"\"\n",
    "    i, j = np.meshgrid(np.arange(max(rows - 1, 0)), np.arange(max(cols - 1, 0)), indexing=\"ij\")\n",
    "    k = (j + i * cols).ravel()\n",
    "    first = np.stack([k, k + cols, k + 1 + cols], axis=1)\n",
    "    second = np.stack([k, k + 1 + cols, k + 1], axis=1)\n",
    "    # Interleave so each cell's two triangles are adjacent in the output.\n",
    "    return np.stack([first, second], axis=1).reshape(-1, 3)\n",
    "\n",
    "# Create mesh data for texture mapping\n",
    "def mesh_data(img, n_colors=32, n_training_pixels=1000):\n",
    "    \"\"\"Return (I, J, K, tri_color_intensity, colorscale) for texturing `img` onto a regular grid mesh.\n",
    "\n",
    "    I, J, K are the vertex-index columns of the grid triangulation. tri_color_intensity holds one\n",
    "    quantized intensity per triangle, taken from one of that triangle's vertices.\n",
    "    \"\"\"\n",
    "    z_vals, colorscale = image2zvals(img, n_colors=n_colors, n_training_pixels=n_training_pixels)\n",
    "\n",
    "    height, width, _ = img.shape\n",
    "    tris = regular_triangles(height, width)\n",
    "    I, J, K = tris.T\n",
    "\n",
    "    per_vertex = z_vals.flatten()[tris]  # one row of 3 vertex intensities per triangle\n",
    "    # Even-indexed triangles use their 2nd vertex, odd-indexed ones their 3rd.\n",
    "    tri_color_intensity = [corners[2] if n % 2 else corners[1] for n, corners in enumerate(per_vertex)]\n",
    "\n",
    "    return I, J, K, tri_color_intensity, colorscale\n",
    "\n",
    "# Build a Mesh3d trace that texture-maps a downsampled image onto the test surface\n",
    "def create_mesh3d(img, resolution=64, n_colors=256, view_idx=0):\n",
    "    \"\"\"Return a go.Mesh3d trace texturing `img` onto a (resolution x resolution) test surface.\n",
    "\n",
    "    The image is resized to a square grid and color-quantized to `n_colors` levels via mesh_data.\n",
    "    The surface is shifted along z by `view_idx` so several views can share one figure.\n",
    "    \"\"\"\n",
    "    # PIL's resize takes (width, height); Image.fromarray presumably needs a uint8 array -- verify.\n",
    "    small = np.array(Image.fromarray(img).resize((resolution, resolution)))\n",
    "    n_rows, n_cols, _ = small.shape\n",
    "    xs, ys, zs = surface(n_rows, n_cols)\n",
    "    I, J, K, intensity, colorscale = mesh_data(small, n_colors=n_colors)\n",
    "    return go.Mesh3d(\n",
    "        x=xs.flatten(),\n",
    "        y=np.flipud(ys).flatten(),  # flip so image row 0 ends up at the top (largest y)\n",
    "        z=zs.flatten() + view_idx,  # offset each view along z\n",
    "        i=I, j=J, k=K,\n",
    "        intensity=intensity,\n",
    "        intensitymode=\"cell\",\n",
    "        colorscale=colorscale,\n",
    "        showscale=False,\n",
    "        name=f\"Image {view_idx}\"\n",
    "    )\n",
    "\n",
    "# Load two images for testing; convert to RGB so grayscale/palette files still give (H, W, 3) arrays\n",
    "image1_path = '/path/to/beef_jerky/IMG_0050.jpg'\n",
    "image2_path = '/path/to/beef_jerky/IMG_0051.jpg'\n",
    "\n",
    "image1 = np.array(Image.open(image1_path).convert(\"RGB\"))\n",
    "image2 = np.array(Image.open(image2_path).convert(\"RGB\"))\n",
    "\n",
    "# Test with one surface using Mesh3d\n",
    "fig1 = go.Figure()\n",
    "\n",
    "# Add the first mesh with the first image\n",
    "mesh1 = create_mesh3d(image1, resolution=256, n_colors=128, view_idx=0)\n",
    "fig1.add_trace(mesh1)\n",
    "\n",
    "fig1.update_layout(\n",
    "    title=\"One Surface Mesh3d Test (Using Surface Mesh and Downsampled Image)\",\n",
    "    scene=dict(aspectmode='data')\n",
    ")\n",
    "\n",
    "fig1.show()\n",
    "\n",
    "# Test with two surfaces using Mesh3d (offset along z by view_idx)\n",
    "fig2 = go.Figure()\n",
    "\n",
    "# Add the first mesh with the first image\n",
    "mesh1 = create_mesh3d(image1, resolution=128, n_colors=256, view_idx=0)\n",
    "fig2.add_trace(mesh1)\n",
    "\n",
    "# Add the second mesh with the second image\n",
    "mesh2 = create_mesh3d(image2, resolution=64, n_colors=256, view_idx=1)\n",
    "fig2.add_trace(mesh2)\n",
    "\n",
    "fig2.update_layout(\n",
    "    title=\"Two Surface Mesh3d Test (Using Surface Mesh and Downsampled Image)\",\n",
    "    scene=dict(aspectmode='data')\n",
    ")\n",
    "\n",
    "fig2.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import plotly.graph_objects as go\n",
    "from PIL import Image\n",
    "\n",
    "def create_camera_frustum_with_image(pose_c2w, focal, H, W, image=None, scale=0.05, color='blue', resolution=64):\n",
    "    \"\"\"Build Plotly traces for a camera frustum, optionally with its image drawn on the frustum base.\n",
    "\n",
    "    Args:\n",
    "        pose_c2w: 4x4 camera-to-world matrix.\n",
    "        focal: Focal length; the frustum depth is focal * scale.\n",
    "        H, W: Image height and width, used only for the frustum's aspect ratio.\n",
    "        image: Optional PIL image, drawn in grayscale on the frustum base.\n",
    "        scale: Multiplier converting focal length to frustum depth.\n",
    "        color: Line color of the frustum edges.\n",
    "        resolution: Side length the image is resized to before being drawn.\n",
    "\n",
    "    Returns:\n",
    "        [frustum_trace], or [frustum_trace, image_surface_trace] when an image is given.\n",
    "    \"\"\"\n",
    "    # Frustum points in camera space: apex at the origin, base rectangle at z = depth.\n",
    "    depth = focal * scale\n",
    "    aspect = W / H  # width / height\n",
    "\n",
    "    frustum_points = np.array([\n",
    "        [0, 0, 0],  # Camera origin\n",
    "        [-aspect * depth, -depth, depth],  # Base corner (-x, -y)\n",
    "        [aspect * depth, -depth, depth],  # Base corner (+x, -y)\n",
    "        [aspect * depth, depth, depth],  # Base corner (+x, +y)\n",
    "        [-aspect * depth, depth, depth],  # Base corner (-x, +y)\n",
    "    ])\n",
    "\n",
    "    # Transform frustum points to world coordinates via homogeneous coordinates\n",
    "    frustum_points_homogeneous = np.hstack([frustum_points, np.ones((frustum_points.shape[0], 1))])\n",
    "    frustum_points_world = (pose_c2w @ frustum_points_homogeneous.T).T[:, :3]\n",
    "\n",
    "    # Frustum edges: apex to each base corner, then around the base rectangle\n",
    "    edges = [\n",
    "        (0, 1), (0, 2), (0, 3), (0, 4),\n",
    "        (1, 2), (2, 3), (3, 4), (4, 1)\n",
    "    ]\n",
    "\n",
    "    # Combine all edges into one trace; a None between segments breaks the polyline\n",
    "    x_vals, y_vals, z_vals = [], [], []\n",
    "    for a, b in edges:\n",
    "        x_vals += [frustum_points_world[a, 0], frustum_points_world[b, 0], None]\n",
    "        y_vals += [frustum_points_world[a, 1], frustum_points_world[b, 1], None]\n",
    "        z_vals += [frustum_points_world[a, 2], frustum_points_world[b, 2], None]\n",
    "\n",
    "    frustum_trace = go.Scatter3d(\n",
    "        x=x_vals,\n",
    "        y=y_vals,\n",
    "        z=z_vals,\n",
    "        mode='lines',\n",
    "        line=dict(color=color),\n",
    "        name=\"Camera Frustum\",\n",
    "        legendgroup=f\"frustum_{id(pose_c2w)}\",\n",
    "        showlegend=True\n",
    "    )\n",
    "\n",
    "    if image is None:\n",
    "        return [frustum_trace]\n",
    "\n",
    "    # Force 3-channel RGB: an alpha channel would skew the channel mean below,\n",
    "    # and grayscale/palette images have no channel axis at all.\n",
    "    img = np.array(image.convert('RGB').resize((resolution, resolution)))\n",
    "    H_img, W_img, _ = img.shape\n",
    "\n",
    "    # Parametric (u, v) grid over the frustum base\n",
    "    uu, vv = np.meshgrid(np.linspace(0, 1, W_img), np.linspace(0, 1, H_img))\n",
    "\n",
    "    # Base corners in world space, in the order of frustum_points[1:5]\n",
    "    c0, c1, c2, c3 = frustum_points_world[1:5]\n",
    "\n",
    "    # Bilinear interpolation: (u, v) = (0, 0) -> c0, (1, 0) -> c1, (1, 1) -> c2, (0, 1) -> c3\n",
    "    def bilerp(axis):\n",
    "        return (c0[axis] * (1 - uu) * (1 - vv) + c1[axis] * uu * (1 - vv)\n",
    "                + c3[axis] * (1 - uu) * vv + c2[axis] * uu * vv)\n",
    "\n",
    "    X, Y, Z = bilerp(0), bilerp(1), bilerp(2)\n",
    "\n",
    "    # Grayscale intensity = mean of the RGB channels, normalized to [0, 1]\n",
    "    grayscale_img = np.mean(img, axis=-1) / 255.0\n",
    "\n",
    "    image_surface_trace = go.Surface(\n",
    "        x=X,\n",
    "        y=Y,\n",
    "        z=Z,\n",
    "        surfacecolor=grayscale_img,\n",
    "        colorscale='gray',\n",
    "        cmin=0, cmax=1,  # keep true intensities instead of auto-stretching to the data range\n",
    "        showscale=False,\n",
    "        name=\"Camera Frustum Image\",\n",
    "        legendgroup=f\"frustum_{id(pose_c2w)}\",  # Link with the frustum lines\n",
    "        showlegend=False  # Toggled through the frustum lines' legend entry\n",
    "    )\n",
    "\n",
    "    return [frustum_trace, image_surface_trace]\n",
    "\n",
    "def plot_cameras(camera_poses, focals, H, W, images=None, scale=0.05, resolution=64):\n",
    "    \"\"\"Plot camera frustums (optionally with their images) in one interactive 3D figure.\n",
    "\n",
    "    Args:\n",
    "        camera_poses: Sequence of 4x4 camera-to-world matrices.\n",
    "        focals: One focal length per camera.\n",
    "        H, W: Image height and width, shared by all cameras.\n",
    "        images: Optional sequence of PIL images, one per camera.\n",
    "        scale, resolution: Forwarded to create_camera_frustum_with_image.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If there are fewer focals or images than cameras.\n",
    "    \"\"\"\n",
    "    camera_poses, focals = list(camera_poses), list(focals)\n",
    "    # zip() would silently drop cameras when focals run short, so check explicitly.\n",
    "    if len(focals) < len(camera_poses):\n",
    "        raise ValueError(f\"Expected {len(camera_poses)} focals, got {len(focals)}\")\n",
    "    if images is not None and len(images) < len(camera_poses):\n",
    "        raise ValueError(f\"Expected {len(camera_poses)} images, got {len(images)}\")\n",
    "\n",
    "    fig = go.Figure()\n",
    "\n",
    "    # Add camera frustums to the plot\n",
    "    for i, (pose_c2w, focal) in enumerate(zip(camera_poses, focals)):\n",
    "        image = images[i] if images is not None else None\n",
    "        # Wrap the red channel so it stays a valid 0-255 value for any number of cameras.\n",
    "        color = f'rgb({(50 * i) % 256}, 100, 150)'\n",
    "        for trace in create_camera_frustum_with_image(pose_c2w, focal, H, W, image, scale, color=color, resolution=resolution):\n",
    "            if trace is not None:\n",
    "                fig.add_trace(trace)\n",
    "\n",
    "    # Equal-scale axes and a fixed oblique viewpoint\n",
    "    fig.update_layout(scene=dict(aspectmode='data'),\n",
    "                      scene_camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),\n",
    "                      title=\"Camera Poses and Frustums with Images\")\n",
    "\n",
    "    fig.show()\n",
    "\n",
    "# Example usage with camera poses (4x4 camera-to-world matrices) and focals\n",
    "camera_poses = [\n",
    "    np.eye(4),  # Identity pose for the first camera\n",
    "    np.array([[1, 0, 0, 0.5], [0, 1, 0, 0.5], [0, 0, 1, 0.5], [0, 0, 0, 1]])  # Second camera: no rotation, translated by 0.5 along x, y and z\n",
    "]\n",
    "# Frustum depth is focal * scale, so these two frustums differ 25x in size\n",
    "focals = [20, 500]  # Focal lengths of the cameras\n",
    "H, W = 1080, 1920  # Example image dimensions (only the W/H aspect ratio shapes the frustum)\n",
    "\n",
    "# Load example images (replace with real images). NOTE: both cameras currently load the same file,\n",
    "# and image1/image2 are rebound here as PIL images (the frustum helper resizes them with PIL).\n",
    "image1 = Image.open('/path/to/beef_jerky/IMG_0050.jpg')\n",
    "image2 = Image.open('/path/to/beef_jerky/IMG_0050.jpg')\n",
    "\n",
    "plot_cameras(camera_poses, focals, H, W, images=[image1, image2], scale=0.1, resolution=128)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ScanNet++"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "##### ScanNetpp_Multiview\n",
    "\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import itertools\n",
    "import json\n",
    "import os.path as osp\n",
    "from collections import deque\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from fast3r.dust3r.datasets.scannetpp_multiview import ScanNetpp_Multiview\n",
    "\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "\n",
    "\n",
    "dataset = ScanNetpp_Multiview(num_views=8, data_scaling=0.5, window_size=10, num_samples_per_window=1, split='train', ordered=True, ROOT=\"/path/to/dust3r_data/scannetpp_processed\", resolution=512, aug_crop=16)\n",
    "\n",
    "# Visualize one randomly chosen sample (the loop breaks after the first iteration).\n",
    "for idx in np.random.permutation(len(dataset)):\n",
    "    views = dataset[idx]\n",
    "    assert len(views) == dataset.num_views\n",
    "    print([view_name(view) for view in views])\n",
    "    viz = SceneViz()\n",
    "    poses = [views[view_idx]['camera_pose'] for view_idx in range(dataset.num_views)]\n",
    "    cam_size = max(auto_cam_size(poses), 1)\n",
    "    for view_idx in range(dataset.num_views):\n",
    "        pts3d = views[view_idx]['pts3d']\n",
    "        valid_mask = views[view_idx]['valid_mask']\n",
    "        colors = rgb(views[view_idx]['img'])\n",
    "        viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "        # Green -> red gradient across views; (view_idx*255, (1-view_idx)*255, 0) leaves [0, 255] past 2 views.\n",
    "        t = view_idx / max(dataset.num_views - 1, 1)\n",
    "        viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n",
    "                        focal=views[view_idx]['camera_intrinsics'][0, 0],\n",
    "                        color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                        image=colors,\n",
    "                        cam_size=cam_size)\n",
    "    display(viz.show(point_size=100, viewer=\"notebook\"))\n",
    "    break\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `N @ dataset` presumably wraps the dataset to a fixed length of N samples (dust3r convention) -- verify.\n",
    "# Not idempotent: re-running this cell wraps the already-wrapped dataset again.\n",
    "dataset = 80_000 @ dataset\n",
    "dataset.set_epoch(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "# Time loading a single sample (index 100) from the dataset.\n",
    "dataset[100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize a fixed sample. The `N @ dataset` wrapper may not expose num_views,\n",
    "# so the view count comes from the sample itself.\n",
    "views = dataset[1005]\n",
    "n_views = len(views)\n",
    "assert n_views == getattr(dataset, 'num_views', n_views)\n",
    "print([view_name(view) for view in views])\n",
    "viz = SceneViz()\n",
    "poses = [view['camera_pose'] for view in views]\n",
    "cam_size = max(auto_cam_size(poses), 1)\n",
    "for view_idx, view in enumerate(views):\n",
    "    colors = rgb(view['img'])\n",
    "    viz.add_pointcloud(view['pts3d'], colors, view['valid_mask'])\n",
    "    # Color by view position (green -> red), not by a leftover loop variable from another cell.\n",
    "    t = view_idx / max(n_views - 1, 1)\n",
    "    viz.add_camera(pose_c2w=view['camera_pose'],\n",
    "                    focal=view['camera_intrinsics'][0, 0],\n",
    "                    color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                    image=colors,\n",
    "                    cam_size=cam_size)\n",
    "display(viz.show(point_size=100, viewer=\"notebook\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##### MegaDepth\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import itertools\n",
    "import json\n",
    "import os.path as osp\n",
    "from collections import deque\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Import from fast3r.dust3r like the other cells, so view_name/rgb/SceneViz are not rebound\n",
    "# to a separately installed top-level `dust3r` package.\n",
    "# from fast3r.dust3r.datasets.megadepth_multiview import MegaDepth_Multiview\n",
    "from fast3r.dust3r.datasets.megadepth import MegaDepth\n",
    "\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "\n",
    "\n",
    "# dataset = MegaDepth_Multiview(split='train', num_views=4, window_size=60, num_samples_per_window=100, ROOT=\"/path/to/dust3r_data/megadepth_processed\", resolution=512, aug_crop=16)\n",
    "dataset = MegaDepth(split='train', ROOT=\"/path/to/dust3r_data/megadepth_processed\", resolution=512, aug_crop=16)\n",
    "\n",
    "views = dataset[0]\n",
    "assert len(views) == dataset.num_views\n",
    "print([view_name(view) for view in views])\n",
    "viz = SceneViz()\n",
    "poses = [views[view_idx]['camera_pose'] for view_idx in range(dataset.num_views)]\n",
    "cam_size = max(auto_cam_size(poses), 1)\n",
    "for view_idx in range(dataset.num_views):\n",
    "    pts3d = views[view_idx]['pts3d']\n",
    "    valid_mask = views[view_idx]['valid_mask']\n",
    "    colors = rgb(views[view_idx]['img'])\n",
    "    viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "    t = view_idx / max(dataset.num_views - 1, 1)  # green -> red across views, always within [0, 255]\n",
    "    viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n",
    "                    focal=views[view_idx]['camera_intrinsics'][0, 0],\n",
    "                    color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                    image=colors,\n",
    "                    cam_size=cam_size)\n",
    "display(viz.show(point_size=100, viewer=\"notebook\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##### MegaDepth_Multiview\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import itertools\n",
    "import json\n",
    "import os.path as osp\n",
    "from collections import deque\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from fast3r.dust3r.datasets.megadepth_multiview import MegaDepth_Multiview\n",
    "\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "\n",
    "\n",
    "# dataset = MegaDepth_Multiview(split='train', num_views=20, window_size=40, num_samples_per_window=1, ROOT=\"/path/to/dust3r_data/megadepth_processed\", resolution=512, aug_crop=16)\n",
    "dataset = 100 @ MegaDepth_Multiview(split='val', num_views=12, window_size=24, num_samples_per_window=100, ROOT=\"/path/to/dust3r_data/megadepth_processed\", resolution=(512, 336), seed=777)\n",
    "dataset.set_epoch(0)\n",
    "print(dataset)\n",
    "\n",
    "# The `100 @ ...` wrapper may not expose num_views, so the view count comes from each sample.\n",
    "for idx in np.random.permutation(len(dataset)):\n",
    "    views = dataset[idx]\n",
    "    n_views = len(views)\n",
    "    viz = SceneViz()\n",
    "    poses = [views[view_idx]['camera_pose'] for view_idx in range(n_views)]\n",
    "    cam_size = max(auto_cam_size(poses), 1)\n",
    "    for view_idx in range(n_views):\n",
    "        pts3d = views[view_idx]['pts3d']\n",
    "        valid_mask = views[view_idx]['valid_mask']\n",
    "        colors = rgb(views[view_idx]['img'])\n",
    "        viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "        # Green -> red gradient; (view_idx*255, (1-view_idx)*255, 0) leaves [0, 255] past 2 views.\n",
    "        t = view_idx / max(n_views - 1, 1)\n",
    "        viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n",
    "                        # focal=views[view_idx]['camera_intrinsics'][0, 0],\n",
    "                        color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                        image=colors,\n",
    "                        cam_size=cam_size)\n",
    "    display(viz.show(point_size=100, viewer=\"notebook\"))\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the RGB images and the valid masks of one sample\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Load the images from the views\n",
    "views = dataset[89]\n",
    "images = [rgb(view['img']) for view in views]\n",
    "\n",
    "# squeeze=False keeps `axes` 2-D; plt.subplots(1, 1) would otherwise return a bare Axes with no `.flat`.\n",
    "fig, axes = plt.subplots(1, len(images), figsize=(40, 8), squeeze=False)\n",
    "for ax, frame in zip(axes[0], images):\n",
    "    ax.imshow(frame)\n",
    "    ax.axis('off')\n",
    "plt.show()\n",
    "\n",
    "# Valid masks, in the same order as the images above\n",
    "fig, axes = plt.subplots(1, len(views), figsize=(40, 8), squeeze=False)\n",
    "for ax, view in zip(axes[0], views):\n",
    "    ax.imshow(view['valid_mask'])\n",
    "    ax.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect which fields each view dict provides\n",
    "views[0].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: load every sample and build its SceneViz (rendering is disabled below).\n",
    "for idx in np.random.permutation(len(dataset)):\n",
    "    views = dataset[idx]\n",
    "    n_views = len(views)\n",
    "    # The `N @ dataset` wrapper may not expose num_views; only check it when available.\n",
    "    assert n_views == getattr(dataset, 'num_views', n_views)\n",
    "    print([view_name(view) for view in views])\n",
    "    viz = SceneViz()\n",
    "    poses = [view['camera_pose'] for view in views]\n",
    "    cam_size = max(auto_cam_size(poses), 1)\n",
    "    for view_idx, view in enumerate(views):\n",
    "        colors = rgb(view['img'])\n",
    "        viz.add_pointcloud(view['pts3d'], colors, view['valid_mask'])\n",
    "        t = view_idx / max(n_views - 1, 1)  # green -> red across views, always within [0, 255]\n",
    "        viz.add_camera(pose_c2w=view['camera_pose'],\n",
    "                        focal=view['camera_intrinsics'][0, 0],\n",
    "                        color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                        image=colors,\n",
    "                        cam_size=cam_size)\n",
    "    # display(viz.show(point_size=100, viewer=\"notebook\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ArkitScenes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ArkitScenes_Multiview\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from fast3r.dust3r.datasets.arkitscenes_multiview import ARKitScenes_Multiview\n",
    "\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "dataset = ARKitScenes_Multiview(\n",
    "    split='train', data_scaling=0.5, num_views=20, window_size=30, num_samples_per_window=2, ROOT=\"/path/to/dust3r_data/arkitscenes_processed\", resolution=(512,100), aug_crop=256\n",
    ")\n",
    "\n",
    "# Visualize one randomly chosen sample (the loop breaks after the first iteration).\n",
    "for idx in np.random.permutation(len(dataset)):\n",
    "    views = dataset[idx]\n",
    "    assert len(views) == dataset.num_views\n",
    "    print(dataset.num_views)\n",
    "    print([view_name(view) for view in views])\n",
    "    viz = SceneViz()\n",
    "    poses = [views[view_idx]['camera_pose'] for view_idx in range(dataset.num_views)]\n",
    "    cam_size = max(auto_cam_size(poses), 0.2)\n",
    "    for view_idx in range(dataset.num_views):\n",
    "        pts3d = views[view_idx]['pts3d']\n",
    "        valid_mask = views[view_idx]['valid_mask']\n",
    "        colors = rgb(views[view_idx]['img'])\n",
    "        viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "        # Green -> red gradient; (view_idx*255, (1-view_idx)*255, 0) leaves [0, 255] past 2 views.\n",
    "        t = view_idx / max(dataset.num_views - 1, 1)\n",
    "        viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n",
    "                        # focal=views[view_idx]['camera_intrinsics'][0, 0],\n",
    "                        color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                        image=colors,\n",
    "                        cam_size=cam_size)\n",
    "    display(viz.show())\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize a fixed sample (index 100). This cell has no loop, so it must not end in `break`.\n",
    "views = dataset[100]\n",
    "assert len(views) == dataset.num_views\n",
    "print(dataset.num_views)\n",
    "print([view_name(view) for view in views])\n",
    "viz = SceneViz()\n",
    "poses = [views[view_idx]['camera_pose'] for view_idx in range(dataset.num_views)]\n",
    "cam_size = max(auto_cam_size(poses), 0.001)\n",
    "for view_idx in range(dataset.num_views):\n",
    "    pts3d = views[view_idx]['pts3d']\n",
    "    valid_mask = views[view_idx]['valid_mask']\n",
    "    colors = rgb(views[view_idx]['img'])\n",
    "    viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "    t = view_idx / max(dataset.num_views - 1, 1)  # green -> red across views, always within [0, 255]\n",
    "    viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n",
    "                    focal=views[view_idx]['camera_intrinsics'][0, 0],\n",
    "                    color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                    image=colors,\n",
    "                    cam_size=cam_size)\n",
    "display(viz.show())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the images of the current `views` in a 4-column grid sized to the number of views\n",
    "import math\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Load the images from the views\n",
    "images = [rgb(view['img']) for view in views]\n",
    "\n",
    "# A fixed 2x4 grid would raise IndexError for fewer than 8 views and silently drop views past 8.\n",
    "n_cols = 4\n",
    "n_rows = max(math.ceil(len(images) / n_cols), 1)\n",
    "fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    ax.axis('off')\n",
    "    if i < len(images):  # the last row may have empty cells\n",
    "        ax.imshow(images[i])\n",
    "        ax.set_title(f\"View {i}\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Habitat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Habitat\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from fast3r.dust3r.datasets.habitat import Habitat\n",
    "\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "dataset = Habitat(1_000, split='train', ROOT=\"/path/to/dust3r_data/habitat_processed\",\n",
    "                    resolution=224, aug_crop=16)\n",
    "\n",
    "# Visualize one randomly chosen pair (the loop breaks after the first iteration).\n",
    "for idx in np.random.permutation(len(dataset)):\n",
    "    views = dataset[idx]\n",
    "    assert len(views) == 2\n",
    "    print(view_name(views[0]), view_name(views[1]))\n",
    "    viz = SceneViz()\n",
    "    poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]\n",
    "    cam_size = max(auto_cam_size(poses), 0.001)\n",
    "    for view_idx in [0, 1]:\n",
    "        pts3d = views[view_idx]['pts3d']\n",
    "        valid_mask = views[view_idx]['valid_mask']\n",
    "        colors = rgb(views[view_idx]['img'])\n",
    "        viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "        viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n",
    "                        focal=views[view_idx]['camera_intrinsics'][0, 0],\n",
    "                        # Color by view (0 -> green, 1 -> red), not by the random dataset index `idx`.\n",
    "                        color=(view_idx * 255, (1 - view_idx) * 255, 0),\n",
    "                        image=colors,\n",
    "                        cam_size=cam_size)\n",
    "    display(viz.show())\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Habitat_Multiview\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from fast3r.dust3r.datasets.habitat_multiview import Habitat_Multiview\n",
    "\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "dataset = Habitat_Multiview(1_000, data_scaling=0.5, split='train', num_views=12, ROOT=\"/path/to/dust3r_data/habitat_processed\", aug_crop=16, resolution=512)\n",
    "# dataset = Habitat_Multiview(1_000_000, split='train', num_views=4, ROOT='/path/to/dust3r_data/habitat_processed', aug_crop=16, resolution=(512,384))\n",
    "# dataset = 100 @ Habitat_Multiview(100000, split='val', num_views=12, ROOT=\"/path/to/dust3r_data/habitat_processed\", resolution=(512,384), seed=777)\n",
    "dataset.set_epoch(0)\n",
    "print(len(dataset))\n",
    "\n",
    "# Visualize one randomly chosen sample (the loop breaks after the first iteration).\n",
    "for idx in np.random.permutation(len(dataset)):\n",
    "    views = dataset[idx]\n",
    "    assert len(views) == dataset.num_views\n",
    "    print(len(views))\n",
    "    print([view_name(view) for view in views])\n",
    "    viz = SceneViz()\n",
    "    # Size cameras from ALL poses, not just the first two.\n",
    "    poses = [views[view_idx]['camera_pose'] for view_idx in range(dataset.num_views)]\n",
    "    cam_size = max(auto_cam_size(poses), 0.2)\n",
    "    for view_idx in range(dataset.num_views):\n",
    "        pts3d = views[view_idx]['pts3d']\n",
    "        valid_mask = views[view_idx]['valid_mask']\n",
    "        colors = rgb(views[view_idx]['img'])\n",
    "        viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "        # Color by view position (green -> red), not by the random dataset index `idx`.\n",
    "        t = view_idx / max(dataset.num_views - 1, 1)\n",
    "        viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n",
    "                        focal=views[view_idx]['camera_intrinsics'][0, 0],\n",
    "                        color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                        image=colors,\n",
    "                        cam_size=cam_size)\n",
    "    display(viz.show())\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of the `dataset` built by whichever dataset cell above was run most recently.\n",
    "len(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BlendedMVS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BlendedMVS from Spann3r: draw one random sequence and render its point clouds and cameras.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from fast3r.dust3r.datasets.blendedmvs_multiview import BlendedMVS_Multiview\n",
    "from fast3r.data.components.spann3r_datasets.blendedmvs import BlendMVS\n",
    "\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "# dataset = BlendedMVS_Multiview(split='train', ROOT=\"/path/to/dust3r_data/blendedmvs_processed\", resolution=512, num_views=4, window_size=6, num_samples_per_window=10, ordered=True, aug_crop=16)\n",
    "dataset = BlendMVS(split='train', num_frames=20, num_seq=200, ROOT='/path/to/dust3r_data/datasets_raw/BlendedMVS', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)])\n",
    "\n",
    "dataset.set_epoch(0)\n",
    "print(len(dataset))\n",
    "\n",
    "for idx in np.random.permutation(len(dataset)):\n",
    "    # NOTE(review): second tuple element presumably selects one of the `resolution` entries; confirm.\n",
    "    views = dataset[(idx,0)]\n",
    "    # assert len(views) == dataset.num_views\n",
    "    print(len(views))\n",
    "    print([view_name(view) for view in views])\n",
    "    viz = SceneViz()\n",
    "    # Size the camera frusta from every view's pose, not only the first two.\n",
    "    poses = [view['camera_pose'] for view in views]\n",
    "    cam_size = max(auto_cam_size(poses), 0.5)\n",
    "    num_views = len(views)\n",
    "    for view_idx, view in enumerate(views):\n",
    "        pts3d = view['pts3d']\n",
    "        valid_mask = view['valid_mask']\n",
    "        colors = rgb(view['img'])\n",
    "        viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "        # Camera color goes from green (first view) to red (last view) and stays in [0, 255].\n",
    "        # `idx` is the dataset sample index, not the view index, so it must not drive the color.\n",
    "        t = view_idx / max(num_views - 1, 1)\n",
    "        viz.add_camera(pose_c2w=view['camera_pose'],\n",
    "                       focal=view['camera_intrinsics'][0, 0],\n",
    "                       color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                       image=colors,\n",
    "                       cam_size=cam_size)\n",
    "    display(viz.show())\n",
    "    break  # only visualize the first randomly drawn sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DTU and 7-Scenes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DTU from Spann3r: draw one random test sequence and render its point clouds and cameras.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from fast3r.data.components.spann3r_datasets.dtu import DTU\n",
    "\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "dataset = DTU(split='test', ROOT='/path/to/dust3r_data/dtu_test_mvsnet_release', resolution=512, num_seq=1, full_video=True, kf_every=5)\n",
    "\n",
    "dataset.set_epoch(0)\n",
    "print(len(dataset))\n",
    "\n",
    "for idx in np.random.permutation(len(dataset)):\n",
    "    views = dataset[(idx,0)]\n",
    "    # assert len(views) == dataset.num_views\n",
    "    print(len(views))\n",
    "    print([view_name(view) for view in views])\n",
    "    viz = SceneViz()\n",
    "    # Size the camera frusta from every view's pose, not only the first two.\n",
    "    poses = [view['camera_pose'] for view in views]\n",
    "    cam_size = max(auto_cam_size(poses), 0.5)\n",
    "    num_views = len(views)\n",
    "    for view_idx, view in enumerate(views):\n",
    "        pts3d = view['pts3d']\n",
    "        valid_mask = view['valid_mask']\n",
    "        colors = rgb(view['img'])\n",
    "        viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "        # Camera color goes from green (first view) to red (last view) and stays in [0, 255].\n",
    "        # `idx` is the dataset sample index, not the view index, so it must not drive the color.\n",
    "        t = view_idx / max(num_views - 1, 1)\n",
    "        viz.add_camera(pose_c2w=view['camera_pose'],\n",
    "                       focal=view['camera_intrinsics'][0, 0],\n",
    "                       color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                       image=colors,\n",
    "                       cam_size=cam_size)\n",
    "    display(viz.show())\n",
    "    break  # only visualize the first randomly drawn sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 7-Scenes from Spann3r: draw one random test sequence and render its point clouds and cameras.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import rootutils\n",
    "rootutils.setup_root(\"/path/to/fast3r/fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from fast3r.data.components.spann3r_datasets.seven_scenes import SevenScenes\n",
    "\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "dataset = SevenScenes(split='test', ROOT='/path/to/dust3r_data/7_scenes_processed', resolution=512, num_seq=1, full_video=False, tuple_path=\"/path/to/dust3r_data/7_scenes_processed/\")\n",
    "\n",
    "dataset.set_epoch(0)\n",
    "print(len(dataset))\n",
    "\n",
    "for idx in np.random.permutation(len(dataset)):\n",
    "    views = dataset[(idx,0)]\n",
    "    # assert len(views) == dataset.num_views\n",
    "    print(len(views))\n",
    "    print([view_name(view) for view in views])\n",
    "    viz = SceneViz()\n",
    "    # Size the camera frusta from every view's pose, not only the first two.\n",
    "    poses = [view['camera_pose'] for view in views]\n",
    "    cam_size = max(auto_cam_size(poses), 0.5)\n",
    "    num_views = len(views)\n",
    "    for view_idx, view in enumerate(views):\n",
    "        pts3d = view['pts3d']\n",
    "        valid_mask = view['valid_mask']\n",
    "        colors = rgb(view['img'])\n",
    "        viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "        # Camera color goes from green (first view) to red (last view) and stays in [0, 255].\n",
    "        # `idx` is the dataset sample index, not the view index, so it must not drive the color.\n",
    "        t = view_idx / max(num_views - 1, 1)\n",
    "        viz.add_camera(pose_c2w=view['camera_pose'],\n",
    "                       focal=view['camera_intrinsics'][0, 0],\n",
    "                       color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                       image=colors,\n",
    "                       cam_size=cam_size)\n",
    "    display(viz.show())\n",
    "    break  # only visualize the first randomly drawn sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ASE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ASE_Multiview: draw one random multi-view sample and render its point clouds and cameras.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import rootutils\n",
    "rootutils.setup_root(\"../fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from fast3r.dust3r.datasets.ase_multiview import ASE_Multiview, ASE_Multiview_Simple\n",
    "\n",
    "from fast3r.dust3r.datasets.base.base_stereo_view_dataset import view_name\n",
    "from fast3r.dust3r.utils.image import rgb\n",
    "from fast3r.dust3r.viz import SceneViz, auto_cam_size\n",
    "from IPython.display import display\n",
    "\n",
    "# NOTE: the ROOT paths below are user-specific absolute paths; adjust them for your environment.\n",
    "# dataset = ASE_Multiview(\n",
    "#     split='train', data_scaling=0.5, num_views=30, window_size=30, num_samples_per_window=1, ROOT=\"/home/jianingy/research/fast3r/data/aria\", resolution=512, aug_crop=256\n",
    "# )\n",
    "dataset = ASE_Multiview_Simple(\n",
    "    split='train', data_scaling=0.5, num_views=30, ROOT=\"/home/jianingy/research/fast3r/data/aria\", resolution=512, aug_crop=256\n",
    ")\n",
    "\n",
    "for idx in np.random.permutation(len(dataset)):\n",
    "    views = dataset[idx]\n",
    "    assert len(views) == dataset.num_views\n",
    "    print(dataset.num_views)\n",
    "    print([view_name(view) for view in views])\n",
    "    viz = SceneViz()\n",
    "    poses = [view['camera_pose'] for view in views]\n",
    "    cam_size = max(auto_cam_size(poses), 0.5)\n",
    "    num_views = len(views)\n",
    "    for view_idx, view in enumerate(views):\n",
    "        height, width = view[\"true_shape\"]\n",
    "        pts3d = view['pts3d']\n",
    "        valid_mask = view['valid_mask']\n",
    "        colors = rgb(view['img'])\n",
    "        viz.add_pointcloud(pts3d, colors, valid_mask)\n",
    "        # Portrait views (width < height) get their image axes swapped for the frustum\n",
    "        # texture and are drawn with a 3x larger camera.\n",
    "        is_portrait = width < height\n",
    "        cam_img = (view['img'].swapaxes(1, 2) if is_portrait else view['img']).permute(1, 2, 0)\n",
    "        # Camera color goes from green (first view) to red (last view) and stays in [0, 255]\n",
    "        # for any number of views.\n",
    "        t = view_idx / max(num_views - 1, 1)\n",
    "        viz.add_camera(pose_c2w=view['camera_pose'],\n",
    "                       # focal=view['camera_intrinsics'][0, 0],\n",
    "                       color=(int(255 * t), int(255 * (1 - t)), 0),\n",
    "                       image=np.uint8(cam_img * 127.5 + 127.5),\n",
    "                       cam_size=cam_size * 3 if is_portrait else cam_size)\n",
    "    display(viz.show())\n",
    "    break  # only visualize the first randomly drawn sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the RGB images of `views` from the most recently run dataset cell above\n",
    "# (this cell also relies on `rgb` imported there).\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Load the images from the views\n",
    "images = [rgb(view['img']) for view in views]\n",
    "\n",
    "# Plot the images. squeeze=False keeps `axes` a 2-D array even for a single view;\n",
    "# otherwise plt.subplots returns a bare Axes, which has no `.flat`.\n",
    "fig, axes = plt.subplots(1, len(images), figsize=(40, 8), squeeze=False)\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    ax.imshow(images[i])\n",
    "    ax.axis('off')\n",
    "    # ax.set_title(f\"View {i}\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fast3r",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
